import re
import os
import yaml
import random
import logging
import shutil
import numpy as np
import oneflow as torch
import argparse
from oasr.model import End2EndModel
from oasr.train.optim import BuildOptimizer
from oasr.train.scheduler import BuildScheduler
from oasr.train.epoch_trainner import EpochTrainer
from oasr.train.step_trainer import StepTrainer
from oasr.utils import count_parameters, find_latest_n_checkpoint
from oasr.data.loader import FeatureLoader


def main(args, params, expdir):
    """Build the model, optimizer and scheduler from ``params`` and run training.

    Args:
        args: parsed command-line namespace (seed, ngpu, parallel_mode,
            local_rank, debug, keep_last_n_chkpt, continue_training, ...).
        params: configuration dict loaded from the YAML config file; expects
            'model' and 'train' sections (see the egs/*/conf YAML files).
        expdir: experiment directory used for checkpoints and resume state.
    """
    # Seed every RNG source so runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    model_type = params['model']['type']
    model = End2EndModel[model_type](params['model'])

    # Log total/trainable parameter counts.
    count_parameters(model.named_parameters())

    if args.ngpu >= 1:
        model.cuda()
    logging.info(model)

    optimizer = BuildOptimizer[params['train']['optimizer_type']](
        filter(lambda p: p.requires_grad, model.parameters()), **params['train']['optimizer']
    )
    # Use the module-level `logging` API consistently: the bare `logger` name
    # is only bound inside the __main__ block and would NameError if main()
    # were imported and called from elsewhere.
    logging.info('[Optimizer] Build a %s optimizer!' % params['train']['optimizer_type'])
    scheduler = BuildScheduler[params['train']['scheduler_type']](optimizer, **params['train']['scheduler'])
    logging.info('[Scheduler] Build a %s scheduler!' % params['train']['scheduler_type'])

    # BUG FIX: the flag lives under params['train'], but the original guard
    # tested `'stepwise' in params` (top level) while reading
    # params['train']['stepwise'] — the check and the access disagreed.
    apply_stepwise_trainer = params['train'].get('stepwise', False)
    Trainer = StepTrainer if apply_stepwise_trainer else EpochTrainer
    trainer = Trainer(params, model=model, optimizer=optimizer, scheduler=scheduler, expdir=expdir, ngpu=args.ngpu,
                      parallel_mode=args.parallel_mode, local_rank=args.local_rank, is_debug=args.debug,
                      keep_last_n_chkpt=args.keep_last_n_chkpt)

    if args.continue_training:
        latest_model_chkpt = find_latest_n_checkpoint(expdir, n=1)
        # BUG FIX: only resume when a checkpoint was actually found; the
        # original indexed [0] unconditionally and raised IndexError on a
        # fresh expdir run with --continue_training.
        if latest_model_chkpt:
            trainer.load_model(latest_model_chkpt[0])

            # NOTE: 'lastest_optimizer.pt' (sic) is the filename the trainer
            # writes — kept as-is on purpose.
            optim_state_path = os.path.join(expdir, 'lastest_optimizer.pt')
            if os.path.exists(optim_state_path):
                trainer.load_optimizer_state_dict(optim_state_path)

            trainer.auto_set_start_point(latest_model_chkpt[0])
        else:
            logging.warning('No checkpoint found in %s; training from scratch.' % expdir)

    train_loader = FeatureLoader(params, 'train', ngpu=args.ngpu, mode=args.parallel_mode)

    trainer.train(train_loader=train_loader)


if __name__ == '__main__':
    # Command-line interface: flags are kept byte-identical for script users.
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='egs/aishell/conf/transformer_base.yaml')
    parser.add_argument('-n', '--ngpu', type=int, default=1)
    parser.add_argument('-g', '--gpus', type=str, default='0')
    parser.add_argument('-se', '--seed', type=int, default=1234)
    parser.add_argument('-p', '--parallel_mode', type=str, default='dp')
    parser.add_argument('-r', '--local_rank', type=int, default=0)
    parser.add_argument('-l', '--logging_level', type=str, default='info', choices=['info','debug'])
    parser.add_argument('-lg', '--log_file', type=str, default=None)
    parser.add_argument('-mp', '--mixed_precision', action='store_true', default=False)
    parser.add_argument('-ct', '--continue_training', action='store_true', default=False)
    parser.add_argument('-dir', '--expdir', type=str, default=None)
    parser.add_argument('-im', '--init_model', type=str, default=None)
    parser.add_argument('-ios', '--init_optim_state', type=str, default=None)
    parser.add_argument('-debug', '--debug', action='store_true', default=False)
    parser.add_argument('-knpt', '--keep_last_n_chkpt', type=int, default=30)
    parser.add_argument('-vb', '--verbose', type=int, default=0)
    parser.add_argument('-ol', '--opt_level', type=str, choices=['O0', 'O1', 'O2', 'O3'], default='O1')
    cmd_args = parser.parse_args()

    # NOTE(review): config files are project-authored, so FullLoader is
    # acceptable here; never use it on untrusted input.
    with open(cmd_args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    # Resolve the experiment directory: explicit --expdir wins, otherwise a
    # conventional egs/<data>/exp/<save_name> layout is used.
    if cmd_args.expdir is not None:
        expdir = os.path.join(cmd_args.expdir, params['train']['save_name'])
    else:
        expdir = os.path.join('egs', params['data']['name'], 'exp', params['train']['save_name'])
    # exist_ok avoids the check-then-create race of the original exists()+makedirs().
    os.makedirs(expdir, exist_ok=True)

    # Snapshot the config alongside the experiment for reproducibility.
    shutil.copy(cmd_args.config, os.path.join(expdir, 'config.yaml'))

    logging_level = {
        'info': logging.INFO,
        'debug': logging.DEBUG
    }

    if cmd_args.log_file is not None:
        log_file = cmd_args.log_file
    else:
        # BUG FIX: derive the log name with splitext instead of slicing off a
        # fixed 5 characters, which silently corrupted the name for configs
        # not ending in ".yaml" (e.g. ".yml").
        log_file = os.path.splitext(os.path.basename(cmd_args.config))[0] + '.log'

    LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
    # BUG FIX: also attach a FileHandler so `log_file` actually exists on
    # disk — the original only logged to the console, yet copied `log_file`
    # into expdir at the end, which raised FileNotFoundError.
    logging.basicConfig(
        level=logging_level[cmd_args.logging_level],
        format=LOG_FORMAT,
        handlers=[logging.StreamHandler(), logging.FileHandler(log_file)],
    )
    logger = logging.getLogger(__name__)

    if cmd_args.ngpu > 0:
        # Restrict visible devices before any CUDA context is created.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(cmd_args.gpus)
        logger.info('Set CUDA_VISIBLE_DEVICES as %s' % cmd_args.gpus)

    if cmd_args.parallel_mode == 'ddp':
        # Single-node DDP rendezvous; port is fixed, so only one such job can
        # run per host at a time.
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        os.environ["OMP_NUM_THREADS"] = '1'
        torch.cuda.set_device(cmd_args.local_rank)

    main(cmd_args, params, expdir)
    if not cmd_args.debug:
        shutil.copy(log_file, os.path.join(expdir, 'train.log'))

